import os
import sys
import scipy
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
import warnings
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr, pearsonr
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%aimport
warnings.filterwarnings('ignore')
# Locations of the trained cytoself model outputs and where figures are saved.
CYTOSELF_MODEL_PATH = '/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/models_outputs_deltaNLS_tl_neuroself_sep_TDP43/'
EMBEDDINGS_FOLDER = os.path.join(CYTOSELF_MODEL_PATH, 'embeddings', 'deltaNLS', 'vqindhist1')
SAVE_PATH = '/home/labs/hornsteinlab/Collaboration/MOmaps/outputs/figures/deltaNLS'

# Collect VQ index histograms, labels and tile paths for the training batches.
vqindhist, labels, paths = [], [], []
for batch in [2, 5]:
    for dataset_type in ['trainset', 'valset', 'testset']:
        batch_dir = os.path.join(EMBEDDINGS_FOLDER, f"batch{batch}_16bit_no_downsample")
        cur_vqindhist = np.load(os.path.join(batch_dir, f"vqindhist1_{dataset_type}.npy"))
        cur_labels = np.load(os.path.join(batch_dir, f"vqindhist1_labels_{dataset_type}.npy"))
        cur_paths = np.load(os.path.join(batch_dir, f"vqindhist1_paths_{dataset_type}.npy"))
        # Flatten each sample's histogram into a single feature vector.
        vqindhist.append(cur_vqindhist.reshape(cur_vqindhist.shape[0], -1))
        labels.append(cur_labels)
        paths.append(cur_paths)
vqindhist = np.concatenate(vqindhist)
labels = np.concatenate(labels)
paths = np.concatenate(paths)
print(vqindhist.shape, labels.shape)
print(np.unique(labels).shape)

# One row per tile; 'label' carries the experimental metadata from the path.
hist_df = pd.DataFrame(vqindhist)
hist_df['label'] = labels
hist_df['label'] = hist_df['label'].str.replace("_16bit_no_downsample", "")
hist_df['label'] = hist_df['label'].str.replace(os.sep, "_")
def rearrange_string(s):
    """Reorder an underscore-delimited label into a marker-first field order.

    The incoming label is split on '_' and its first five fields are
    re-emitted in the order 4, 1, 2, 0, 3 (any extra fields are dropped).
    Downstream code reads the result as marker / cell line / condition /
    batch / rep, in that order.
    """
    fields = s.split('_')
    return '_'.join((fields[4], fields[1], fields[2], fields[0], fields[3]))
# Normalise all labels into the shared marker-first field order.
hist_df['label'] = hist_df['label'].map(rearrange_string)
# Keep a copy that also records the originating tile path for later lookups.
hist_df_with_path = hist_df.copy()
hist_df_with_path['path'] = paths
hist_df
# print(hist_df.shape)
# hist_df = hist_df[hist_df.label.str.contains('WT')]
# print(hist_df.shape)
# hist_df
# Average the codeword histograms within each label, then correlate the
# codeword columns across labels.
mean_spectra_per_marker = hist_df.groupby('label').mean()
mean_spectra_per_marker
corr = mean_spectra_per_marker.corr()
corr
# Clustered heatmap of codeword-codeword Pearson correlations.
heatmap_kwargs = dict(cbar_kws=dict(ticks=[-1, 0, 1]))
clustermap = sns.clustermap(corr, center=0, cmap='bwr', vmin=-1, vmax=1,
                            figsize=(9, 5), xticklabels=False, **heatmap_kwargs)
clustermap.ax_row_dendrogram.set_visible(False)
# Move the colorbar to sit just right of the column dendrogram.
dendro_pos = clustermap.ax_col_dendrogram.get_position()
clustermap.ax_cbar.set_position([dendro_pos.x1 + 0.01,       # x location
                                 dendro_pos.y0 + 0.01,       # y location
                                 0.01,                       # width
                                 dendro_pos.height - 0.05])  # height
clustermap.ax_cbar.set_title('Pearson r', fontsize=6)
clustermap.cax.tick_params(axis='y', labelsize=6, length=0, pad=0.1)
plt.show()
clustermap.figure.savefig(os.path.join(SAVE_PATH, "deltaNLS_codeword_idx_corr_heatmap.png"))
def get_clusters(clustermap, corr, cutoff = 14.2):
    # Cut the column dendrogram of `clustermap` at height `cutoff` and assign
    # every index of `corr` to a flat cluster.
    # Left panel: the dendrogram with the cutoff line drawn; right panel:
    # number of indices per cluster.
    # Returns a copy of `corr` with an extra string column 'cluster'
    # ('C1', 'C2', ..., or 'CNone' for any index not found in a cluster).
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8,4))
    den = scipy.cluster.hierarchy.dendrogram(clustermap.dendrogram_col.linkage,
                                             labels = corr.index,
                                             color_threshold=cutoff,
                                             no_labels=True,
                                             ax=axs[0])
    axs[0].axhline(cutoff, c='black', linestyle="-")
    #return den
    def get_cluster_classes(den):
        # Walk the dendrogram leaves left-to-right; a change in leaf color
        # (assigned by color_threshold) marks the start of a new cluster.
        cluster_classes = defaultdict(list)
        seen = []
        cur_cluster = 1
        last_color = den['leaves_color_list'][0]
        for label, color in zip(den['ivl'], den['leaves_color_list']):
            if color != last_color:
                cur_cluster += 1
                last_color = color
            cluster_classes[cur_cluster].append(label)
        return cluster_classes
    clusters = get_cluster_classes(den)
    # Map every index of `corr` back to its cluster id; None if unassigned.
    # NOTE(review): the inner loop has no `break`, so an index appearing in
    # more than one cluster would be appended twice and misalign `cluster`
    # with the index — safe only because dendrogram clusters are disjoint.
    cluster = []
    corr_with_clusters = corr.copy()
    for i in corr_with_clusters.index:
        included=False
        for j in clusters.keys():
            if i in clusters[j]:
                cluster.append(j)
                included=True
        if not included:
            cluster.append(None)
    corr_with_clusters["cluster"] = cluster
    # visualize the cluster counts
    sns.countplot(data=corr_with_clusters.sort_values(by='cluster'), x='cluster', palette='coolwarm', ax=axs[1])
    # Add labels and title
    axs[1].set_xlabel('Cluster')
    axs[1].set_ylabel('Indices Count')
    axs[1].set_title('Indices Counts per Cluster')
    plt.tight_layout()
    # Show
    plt.show()
    # Prefix cluster ids with 'C' so they read as categorical labels.
    corr_with_clusters['cluster'] = corr_with_clusters['cluster'].astype(str)
    corr_with_clusters['cluster'] = 'C' + corr_with_clusters['cluster']
    return corr_with_clusters
# Cut the dendrogram at height 13 to obtain the codeword clusters.
corr_with_clusters = get_clusters(clustermap, corr, cutoff=13)
#corr_with_clusters = get_clusters(clustermap, corr, cutoff = 14)
clusters = np.unique(corr_with_clusters.cluster)

# Per-tile fraction of codewords falling in each cluster.
hist_per_cluster = pd.DataFrame(index=hist_df_with_path.index,
                                columns=list(clusters) + ['label', 'path'])
hist_per_cluster.label = hist_df_with_path.label
hist_per_cluster.path = hist_df_with_path.path
# For each cluster, sum the histogram over that cluster's codeword indices.
# Divides by 625 — presumably the total code count per tile; TODO confirm.
for cluster_label, cluster_group in corr_with_clusters.groupby('cluster'):
    hist_per_cluster[cluster_label] = hist_df_with_path[cluster_group.index].sum(axis=1) / 625
# Dominant cluster per tile (only the numeric cluster columns compete).
hist_per_cluster['max_cluster'] = hist_per_cluster.idxmax(axis=1, numeric_only=True)
# ---- Representative tiles per cluster -------------------------------------
# For every dominant cluster, show the 4 tiles with the highest fraction of
# that cluster's codewords (a 2x2 image grid per cluster), and record which
# markers / cell lines / conditions have tiles dominated (>= 0.5) by each
# cluster as colored presence tables.
fig, axs = plt.subplots(nrows=2*np.unique(hist_per_cluster.max_cluster).size, ncols=2, figsize=(4,32))
# Label fields after rearrange_string: [0]=marker, [1]=cell line, [2]=condition.
unique_markers = np.unique(hist_per_cluster.label.str.split("_").str[0])
unique_cell_lines = np.unique(hist_per_cluster.label.str.split("_").str[1])
unique_conditions = np.unique(hist_per_cluster.label.str.split("_").str[2])
color_light_green = '#8DF980'  # cell color: entity present in cluster
color_gray = 'gray'            # cell color: entity absent
unique_label_per_clusters = {}
# Presence tables (entity x cluster), initialised to all-gray.
unique_marker_per_clusters = pd.DataFrame(color_gray, columns=clusters, index=unique_markers)
unique_cell_lines_per_clusters = pd.DataFrame(color_gray,columns=clusters, index=unique_cell_lines)
unique_conditions_per_clusters = pd.DataFrame(color_gray,columns=clusters, index=unique_conditions)
for i, (max_cluster, max_cluster_group) in enumerate(hist_per_cluster.groupby('max_cluster')):
    # Tiles whose dominant-cluster fraction is at least one half.
    max_cluster_group_thres = max_cluster_group[max_cluster_group[max_cluster] >= 0.5]
    unique_label_per_clusters[max_cluster] = np.unique(max_cluster_group_thres.label)
    unique_marker_per_clusters.loc[np.unique(max_cluster_group_thres.label.str.split("_").str[0]), max_cluster] = color_light_green
    unique_cell_lines_per_clusters.loc[np.unique(max_cluster_group_thres.label.str.split("_").str[1]), max_cluster] = color_light_green
    unique_conditions_per_clusters.loc[np.unique(max_cluster_group_thres.label.str.split("_").str[2]), max_cluster] = color_light_green
    # Top-4 tiles ranked by this cluster's fraction.
    max_tiles_paths = max_cluster_group[[max_cluster,'path']].sort_values(by=max_cluster,ascending=False)[:4].path
    for j, tile_path in enumerate(max_tiles_paths):
        # Paths look like "<site .npy path>_<tile index>"; split at last '_'.
        cut = tile_path.rfind("_")
        real_path = tile_path[:cut]
        tile_number = int(tile_path[cut+1:])
        cur_site = np.load(real_path)
        # 2x2 grid per cluster: this cluster occupies rows 2i and 2i+1.
        ax = axs[i * 2 + j // 2, j%2]
        ax.imshow(cur_site[tile_number,:,:,0], cmap='gray')
        ax.axis('off')
        if j==0:
            # Write the cluster name to the left of the first tile only.
            ax.text(-40,100, max_cluster, fontsize=15)
        # Derive a short per-tile annotation from the file-path components.
        # NOTE(review): indentation reconstructed — these lines are assumed to
        # run for every tile j (not only j==0); confirm against the notebook.
        split_path=real_path.split(os.sep)
        marker = split_path[-2]
        condition = split_path[-3]
        if 'Untreated' in condition:
            condition = condition[:3]
        cell_line = split_path[-4]
        if 'FUS' in cell_line:
            cell_line = cell_line[:6]
        rep = split_path[-1].split("_")[0]
        label = f"{cell_line}_{condition}_\n{marker}_{rep}"
        ax.text(60,95,label, color='yellow', fontsize=6)
plt.subplots_adjust(wspace=0.01, hspace=0.01)
# Save the figure to file
plt.savefig(os.path.join(SAVE_PATH, "deltaNLS_representative_images_per_cluster.png"), bbox_inches='tight')
plt.show()
# Render each presence table as a colored matplotlib table figure.
for df in [unique_marker_per_clusters,unique_conditions_per_clusters,unique_cell_lines_per_clusters]:
    fig, ax = plt.subplots()
    table = ax.table(rowLabels=df.index,
                     colLabels=df.columns,
                     cellLoc='center',
                     rowLoc='center',
                     loc='center',
                     cellColours=df.values)
    plt.axis('off')
    plt.show()
# Load the held-out inference batches (3 and 4, 'all' split only).
vqindhist_inference, labels_inference, paths_inference = [], [], []
for batch in [3, 4]:
    for dataset_type in ['all']:
        batch_dir = os.path.join(EMBEDDINGS_FOLDER, f"batch{batch}_16bit_no_downsample")
        cur_vqindhist = np.load(os.path.join(batch_dir, f"vqindhist1_{dataset_type}.npy"))
        cur_labels = np.load(os.path.join(batch_dir, f"vqindhist1_labels_{dataset_type}.npy"))
        cur_paths = np.load(os.path.join(batch_dir, f"vqindhist1_paths_{dataset_type}.npy"))
        # Flatten each sample's histogram into a single feature vector.
        vqindhist_inference.append(cur_vqindhist.reshape(cur_vqindhist.shape[0], -1))
        labels_inference.append(cur_labels)
        paths_inference.append(cur_paths)
# Add batches that were used in training as well
vqindhist_inference.append(vqindhist)
labels_inference.append(labels)
paths_inference.append(paths)
vqindhist_inference = np.concatenate(vqindhist_inference)
labels_inference = np.concatenate(labels_inference)
paths_inference = np.concatenate(paths_inference)
print(vqindhist_inference.shape, labels_inference.shape)
print(np.unique(labels_inference).shape)
# Generate DataFrame
hist_df_inference = pd.DataFrame(vqindhist_inference)
hist_df_inference['label'] = labels_inference
hist_df_inference['label'] = hist_df_inference['label'].str.replace("_16bit_no_downsample", "")
hist_df_inference['label'] = hist_df_inference['label'].str.replace(os.sep, "_")
def rearrange_string(s):
    """Reorder a path-derived label into marker-first field order.

    Splits on '_' and re-emits fields in the order 4, 1, 2, 0, 3.
    (Re-definition of the helper defined in an earlier notebook cell;
    kept so this cell runs standalone.)
    """
    f = s.split('_')
    return f"{f[4]}_{f[1]}_{f[2]}_{f[0]}_{f[3]}"
# Normalise the inference labels to the same field order as the training set.
hist_df_inference['label'] = hist_df_inference['label'].map(rearrange_string)
hist_df_inference
def plot_heatmap_with_clusters_and_histograms(corr_with_clusters, hist_df, labels,
                                              sep = "_", colormap_name = "viridis",
                                              filename="deltaNLS_plot_heatmap_with_clusters_and_histograms.png"):
    """Plot the clustered codeword-correlation heatmap with cluster boundaries
    and, below it, one mean-histogram strip per selected label.

    Parameters
    ----------
    corr_with_clusters : DataFrame
        Codeword correlation matrix plus a 'cluster' column ('C1', ...),
        as returned by get_clusters.
    hist_df : DataFrame
        Per-tile codeword histograms with a 'label' column.
    labels : list of str
        Exact labels, or substrings expanded to all matching labels
        (e.g. ['_'] selects every label).
    sep : str
        Field separator inside labels.
    colormap_name : str
        Palette for the histogram strips.
    filename : str
        Output file name under SAVE_PATH.
    """
    # Recover the plain correlation matrix from the annotated copy instead of
    # reading the module-level `corr` global (bug fix: the original body used
    # the notebook global and ignored `corr_with_clusters` for the heatmap).
    corr = corr_with_clusters.drop(columns='cluster')
    # create the heatmap and dendrogram
    kws = dict(cbar_kws=dict(ticks=[-1,0,1]))
    clustermap = sns.clustermap(corr, center=0, cmap='bwr', vmin=-1, vmax=1, figsize=(9,5), xticklabels=False, yticklabels=False, col_colors=corr_with_clusters.cluster, **kws)
    clustermap.ax_row_dendrogram.set_visible(False)
    # get the indices order from the dendrogram
    hierarchical_order = clustermap.dendrogram_col.reordered_ind
    # prepare labels and filter histograms of wanted labels
    real_labels = []
    for label in labels:
        if label not in np.unique(hist_df.label):
            # Treat `label` as a substring pattern and expand to all matches.
            real_labels += [real_label for real_label in np.unique(hist_df.label) if label in real_label]
        else:
            real_labels.append(label)
    hist_df_cur = hist_df[hist_df.label.isin(real_labels)]
    cur_groups = real_labels
    # Field order (set by rearrange_string, counted from the end so merged
    # shorter labels still index correctly from the right):
    # -5 marker, -4 cell line, -3 condition, -2 batch, -1 rep.
    splitted_labels = hist_df_cur.label.str.split(sep)
    cur_batches = np.unique(splitted_labels.str[-2])
    cur_markers = np.unique(splitted_labels.str[-5])
    cur_cell_lines = np.unique(splitted_labels.str[-4])
    cur_conditions = np.unique(splitted_labels.str[-3])
    cur_reps = np.unique(splitted_labels.str[-1])
    # Mean the histograms by labels and re-order by the indices order
    total_spectra_per_marker_ordered = hist_df_cur.groupby('label').mean()[hierarchical_order]
    # calc clusters locations
    # NOTE(review): relies on value_counts().reset_index() yielding columns
    # ['cluster', 'count'] (pandas >= 2.0 naming) — confirm pandas version.
    cluster_counts = pd.DataFrame(corr_with_clusters.cluster.value_counts()).reset_index()
    cluster_counts.cluster = cluster_counts.cluster.str.replace('C','').astype('int')
    cluster_counts.sort_values(by='cluster', inplace=True)
    cluster_positions = clustermap.ax_col_dendrogram.get_position()
    num_samples = len(clustermap.dendrogram_col.data)
    line_positions = [cluster_positions.x0 + i * (cluster_positions.width / num_samples) for i in range(1, num_samples)]
    # make room for the histograms in the plot
    hist_height = 0.05
    clustermap.fig.subplots_adjust(top=hist_height*len(cur_groups)+1, bottom=hist_height*len(cur_groups))
    # add axes for the histograms, stacked from the figure bottom upward
    axs=[]
    for i, label in enumerate(cur_groups):
        axs.append(clustermap.fig.add_axes([clustermap.ax_heatmap.get_position().x0, 0+i*hist_height, clustermap.ax_heatmap.get_position().width, hist_height]))
    # create colors
    colors = sns.color_palette(colormap_name, n_colors=len(cur_groups))
    # plot the histograms (labels iterated in reverse so the first group is on top)
    for i, label in enumerate(cur_groups[::-1]):
        d = total_spectra_per_marker_ordered.loc[label, :]
        axs[i].fill_between(range(len(d)), d, color=colors[i], label=label, linewidth=1)
        axs[i].set_xticklabels([])
        axs[i].set_xticks([])
        axs[i].set_yticklabels([])
        axs[i].set_yticks([])
        axs[i].tick_params(axis='y', labelsize=4, length=0, pad=0.1)
        # Build a compact annotation containing only the fields that vary
        # across the selected labels.
        splitted_label = label.split(sep)
        label_for_plot = ''
        if len(cur_cell_lines)>1:
            label_for_plot+= f'{splitted_label[-4]}_'
        if len(cur_conditions)>1:
            label_for_plot+= f'{splitted_label[-3]}_'
        if len(cur_markers)>1:
            label_for_plot+= f'{splitted_label[-5]}_'
        if len(cur_batches)>1:
            label_for_plot+= f'{splitted_label[-2]}_'
        if len(cur_reps)>1:
            label_for_plot+= f'{splitted_label[-1]}'
        if label_for_plot.endswith("_"):
            label_for_plot = label_for_plot[:-1]
        axs[i].text(1.02, 0.5, label_for_plot, transform=axs[i].transAxes,
                    rotation=0, va='center', ha='left')
        # add cluster lines to histograms
        prev_count = 0
        for j, cluster in enumerate(cluster_counts.cluster):
            cur_count = cluster_counts.iloc[j]['count']
            cluster_end = cur_count + prev_count
            axs[i].axvline(x=cluster_end, color='black',linestyle="--", linewidth=0.4)
            prev_count = cluster_end
        #ax.tick_params(axis='y', labelsize=8)
        axs[i].spines['bottom'].set_color('lightgray')
        axs[i].spines['top'].set_color('lightgray')
        axs[i].spines['right'].set_color('lightgray')
        axs[i].spines['left'].set_color('lightgray')
        axs[i].margins(x=0)
    # fix the cbar appearance
    clustermap.ax_cbar.set_position([clustermap.ax_col_dendrogram.get_position().x1+0.01, # x location
                                     clustermap.ax_col_dendrogram.get_position().y0+0.01, # y location
                                     0.01, # width
                                     clustermap.ax_col_dendrogram.get_position().height-0.05]) #height
    clustermap.ax_cbar.set_title('Pearson r',fontsize=6)
    clustermap.cax.tick_params(axis='y', labelsize=6, length=0, pad=0.1)
    # add cluster lines to the heatmap, with the cluster id centered above
    prev_count = 0
    for j, cluster in enumerate(cluster_counts.cluster):
        cur_count = cluster_counts.iloc[j]['count']
        cluster_end = cur_count + prev_count
        clustermap.ax_heatmap.axvline(x=cluster_end, color='black',linestyle="--", linewidth=0.4)
        clustermap.ax_col_colors.text(x=cluster_end-(cur_count/2), y=0.5, s=cluster, fontsize=6)
        prev_count = cluster_end
    clustermap.figure.savefig(os.path.join(SAVE_PATH, filename), bbox_inches='tight')
    return None
# Full figure over every inference label ('_' is a substring matching all).
plot_heatmap_with_clusters_and_histograms(
    corr_with_clusters,
    hist_df_inference,
    labels=['_'],
    filename="deltaNLS_plot_heatmap_with_clusters_and_histograms.png",
)

# Merge batches
# Collapse each label to its first three fields (marker_cellline_condition),
# pooling batches and reps into one group per biological condition.
hist_df_inference_merged = hist_df_inference.copy()
hist_df_inference_merged['label'] = hist_df_inference_merged['label'].str.split("_").str[0:3].str.join("_")
plot_heatmap_with_clusters_and_histograms(
    corr_with_clusters,
    hist_df_inference_merged,
    labels=['_'],
    filename="deltaNLS_plot_heatmap_with_clusters_and_histograms_merged_batches.png",
)
# batches_inference = hist_df_inference[hist_df_inference.label.str.contains(r'batch3|batch4', regex=True)]
# batches_inference = batches_inference[batches_inference.label.str.contains('WT')]
# batches_inference = batches_inference.copy()
# batches_inference['label'] = batches_inference['label'].str.split("_").str[0:3].apply(lambda x: '_'.join(x))
# batches_test_agg
# plot_heatmap_with_clusters_and_histograms(corr_with_clusters, batches_test_agg, labels = ['FMRP','PML','G3BP1',
#                                                                'PURA','TOMM20','SQSTM1','mitotracker',
#                                                                'TDP43','PSD95','DCP1A'])
print("Done!")